
Lift max_batch_size=1 under EP=2 via head↔worker slot multiplexing (#99) #101

Merged
tbraun96 merged 10 commits into Avarok-Cybersecurity:main from camerono:pr/ep-batched-decode
Jun 3, 2026

Conversation

@camerono

Contributor

Implements #99 — lifts the max_batch_size = 1 clamp under --ep-size 2 by multiplexing the head↔worker protocol, and makes the batched multi-sequence decode path it unlocks both correct and fast.

Behind ATLAS_EP_PROTOCOL=v2 so the wire change is opt-in; v1 behaviour is unchanged when the flag is unset.

Protocol (the design from #99, landed)

SequenceState.slot_idx already identifies a sequence's SSM-pool slot and the whole forward pass indexes by it — the protocol just couldn't express which slot the head was operating on, so the worker only ever used slot 0 and serve.rs was forced to clamp max_batch_size = 1. The change plumbs the existing slot_idx through as the seq_id:

  1. ep_broadcast_seq_and_cmd / ep_recv_seq_and_cmd helpers (preamble is skipped when v2 is off).
  2. Head emits the seq_id preamble at every EP broadcast site.
  3. Worker dispatches multi-slot (Vec<Option<SequenceState>> instead of a singleton).
  4. Explicit alloc-slot / free-slot command codes.
  5. Drop the world_size > 1 clamp under v2; pre-allocate all worker slots and skip retire-compaction (the worker keeps slots in place, keyed by slot_idx).
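
A minimal sketch of the framing in step 1, assuming the wire is a plain FIFO of u32 words. `MockWire` is a hypothetical stand-in for the real broadcast transport, and the extra `wire` parameter is not part of the helper signatures in this PR:

```rust
use std::collections::VecDeque;

/// Hypothetical stand-in for the head-to-worker broadcast channel.
#[derive(Default)]
pub struct MockWire {
    words: VecDeque<u32>,
}

impl MockWire {
    pub fn send_u32(&mut self, word: u32) {
        self.words.push_back(word);
    }

    pub fn recv_u32(&mut self) -> Option<u32> {
        self.words.pop_front()
    }

    /// Number of words sent but not yet received.
    pub fn pending(&self) -> usize {
        self.words.len()
    }
}

/// Head side: under v2 the slot id precedes the command; under v1 only the
/// command is sent, so the wire shape is unchanged.
pub fn ep_broadcast_seq_and_cmd(wire: &mut MockWire, seq_id: u32, cmd: u32, v2: bool) {
    if v2 {
        wire.send_u32(seq_id);
    }
    wire.send_u32(cmd);
}

/// Worker side: under v1 no preamble is read and every command maps to slot 0.
pub fn ep_recv_seq_and_cmd(wire: &mut MockWire, v2: bool) -> Option<(u32, u32)> {
    let seq_id = if v2 { wire.recv_u32()? } else { 0 };
    let cmd = wire.recv_u32()?;
    Some((seq_id, cmd))
}
```

Both sides must agree on `v2`: a v2 head paired with a v1 worker would have its seq_id read as a command.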

Two head-side fixes the protocol exposed

These were in the model layer, not the wire protocol:

  • decode_batch_dispatch's EP branch wrote single-row logits to row 0 per sequence, so every sequence sampled the last one's logits. Fixed by staging each sequence's logits row (later, a d2d copy; finally subsumed by the batched path).
  • Broadcasting all seq-id preambles up front put the head's comm-stream op order out of step with the worker's per-layer all-reduces, deadlocking NCCL. Fixed by interleaving the broadcasts inside the dispatch.

Batched-decode MoE: use the fused batch kernels, not the grouped-GEMM

The SSM layers' batched MoE was dead code behind an early return. Reviving it with forward_prefill (the prefill grouped-GEMM) was a ~60% regression — at decode batch sizes the per-expert M is ≈1, and the 256-expert sort/permute/ptr-table overhead dominates, ×36 SSM layers. The dispatch that works:

  • N=2/3 → fused forward_k2 / forward_k3 (one batched all-reduce, no per-token launch overhead)
  • N≥4 → per-token MoE loop

The SSM mixer stays per-sequence (it carries recurrent state); only the stateless MoE is hoisted out and batched. The attention layers had the same forward_prefill-at-N≥4 trap, fixed the same way.
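
The dispatch rule above can be sketched as a single match on the batch size. `MoeDispatch` and `select_moe_dispatch` are illustrative names, not identifiers from the codebase, and the N=1 arm assumes the existing single-sequence short-circuit:

```rust
/// Which MoE path a decode step takes for `n` concurrent sequences.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MoeDispatch {
    /// n == 1: the ordinary single-sequence decode path.
    SingleToken,
    /// n == 2: fused forward_k2, one batched all-reduce.
    FusedK2,
    /// n == 3: fused forward_k3, one batched all-reduce.
    FusedK3,
    /// n >= 4: loop the single-token MoE once per sequence.
    PerTokenLoop,
}

/// Returns `None` for an empty batch, otherwise the path described in the PR.
/// The grouped-GEMM (forward_prefill) is deliberately never selected here.
pub fn select_moe_dispatch(n: usize) -> Option<MoeDispatch> {
    match n {
        0 => None,
        1 => Some(MoeDispatch::SingleToken),
        2 => Some(MoeDispatch::FusedK2),
        3 => Some(MoeDispatch::FusedK3),
        _ => Some(MoeDispatch::PerTokenLoop),
    }
}
```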

Measured (Qwen3.5-122B-A10B-NVFP4, EP=2 on 2× GB10, MTP on)

| Change | Effect |
| --- | --- |
| Gate lift | 4-concurrent burst now batches instead of serializing; removes the batch=1 tail-spike from #99 |
| SSM MoE dispatch (forward_k2) | SSM decode step 44 ms → 35 ms at N=2 (~15–20%) |
| Attention MoE dispatch | attention block ~40 ms → ~24 ms at N=4 |

One expectation to set: lifting the gate buys tail latency and admission, not aggregate throughput at low concurrency. Batched vs serialized at N=2 is ~equal (~32 tok/s either way) — the decode step is memory-bandwidth + inter-node-NCCL bound, and two concurrent tokens share almost no weight loads. Output is coherent and cross-sequence-isolated at N=2 and N=4. Full reasoning, including why CUDA graphs and single-host don't help this model, is in docs/adr/0011-ep-batched-decode-optimization.md.

camerono added 10 commits May 27, 2026 19:11
…ocol

Foundation for atlas#99 (lift max_batch_size=1 under --ep-size N). Adds two
helpers that wrap an optional seq_id preamble around the existing cmd
broadcast:

  ep_broadcast_seq_and_cmd(seq_id, cmd, v2)
  ep_recv_seq_and_cmd(v2) -> (seq_id, cmd)

When v2 is false, the preamble is skipped and the wire shape is byte-
identical to today's protocol — head and worker built before this change
continue to interoperate. When v2 is true, a slot identifier (which will be
SequenceState.slot_idx in subsequent commits) precedes each command so the
worker can route slot-bound dispatch into the right SsmStatePool slot.

No callers yet — this commit is purely additive. Commits 2-5 will:
  - emit the preamble at head-side broadcast sites (scheduler, verify_k*)
  - dispatch multi-slot on the worker (loop in build.rs, slots vec on
    ep_worker_step_impl)
  - split the legacy "free+realloc" 0xFFFFFFF1 into explicit alloc-slot
    and free-slot commands
  - flip the world_size > 1 clamp in serve.rs and switch the scheduler to
    round-robin dispatch over active sequences

v2 wiring lives at the call site for now (caller passes the flag
explicitly, no implicit env-var reads inside the helper) to keep this
commit testable as pure logic and to align with AGENTS.md's PCND
invariant.
Wires up the ep_broadcast_seq_and_cmd helper added in 21e2130 by switching
every command-kickoff site in the scheduler from ep_broadcast_cmd(cmd) to
ep_broadcast_cmd_for_seq(slot_idx, cmd). Follow-on broadcasts within the
same command (chunk metadata, additional tokens, accept/reject result)
keep using ep_broadcast_cmd unchanged — they ride the slot context the
worker picked up from the preamble.

Plumbing:

  * `TransformerModel.ep_protocol_v2: bool` (types.rs) reads
    ATLAS_EP_PROTOCOL env at construction (impl_a1.rs).
  * `Model::ep_broadcast_cmd_for_seq(seq_id, cmd)` trait method added in
    traits/model.rs (default no-op); TransformerModel override in
    trait_impl/mod.rs routes through the helper using its v2 flag.
  * `Model::ep_protocol_v2() -> bool` trait method added (default false).

Call sites migrated (kickoff cmd codes 0xFFFFFFF0..F4 + token-level
decode kickoffs that broadcast a.last_token):

  scheduler/verify_k2_step.rs    : K=2 marker
  scheduler/verify_k3_step.rs    : K=3 marker
  scheduler/verify_k4_step.rs    : K=4 marker
  scheduler/prefill_a_step.rs    : chunk-0 prefill + 2 error-recovery
  scheduler/prefill_b_step.rs    : single-shot prefill + 1 error-recovery
  scheduler/phase_continue_prefills/run_standard.rs : chunked prefill
  scheduler/phase_promote_prefills.rs : error-path free
  scheduler/lifecycle.rs         : finish_sequence + send_error + swap
  scheduler/mod.rs               : scheduler shutdown drain + shutdown cmd
  scheduler/spec_step.rs         : self-spec + ngram bootstrap + ngram K=2
  scheduler/mtp_step.rs          : MTP bootstrap decode

For sites with `&mut ActiveSeq` in scope we use `a.seq.slot_idx as u32`;
for sites with a bare `&mut SequenceState` (prefill paths,
phase_promote_prefills error path) we use `seq.slot_idx as u32`. The
scheduler shutdown command (0xFFFFFFFF) isn't slot-bound; seq_id is
broadcast as 0 by convention and the worker ignores it.

Wire effect:

  * ATLAS_EP_PROTOCOL unset / != "v2": ep_protocol_v2 is false, the
    helper skips the preamble, broadcast shape is byte-identical to today.
  * ATLAS_EP_PROTOCOL=v2: head broadcasts (slot_idx, cmd) as two
    sequential u32s before any follow-on data. Worker doesn't yet
    consume the preamble — that's commit 3.
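
A sketch of the flag read, assuming an exact, case-sensitive match on "v2" (the commit only states that unset or != "v2" means v1). The function names are hypothetical; parsing is split from the env read so it stays testable as pure logic:

```rust
/// Pure parsing step: only the exact string "v2" enables the new protocol.
pub fn parse_ep_protocol_v2(value: Option<&str>) -> bool {
    value == Some("v2")
}

/// Reads ATLAS_EP_PROTOCOL once; intended to be called at model construction
/// and cached in a field rather than re-read on each broadcast.
pub fn ep_protocol_v2_from_env() -> bool {
    parse_ep_protocol_v2(std::env::var("ATLAS_EP_PROTOCOL").ok().as_deref())
}
```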

Single-rank (no comm) and v1 multi-rank both keep working through this
commit. v2 isn't end-to-end until commit 3 lands the worker dispatch.
Closes out the v2 protocol end-to-end by teaching the worker to maintain
N parallel `SequenceState` slots and route incoming commands by the
seq_id preamble emitted in commit 31e44c6.

Refactor in `model/impl_a2.rs`:

  * `ep_worker_step_impl` now takes `&mut [Option<SequenceState>]`. It
    reads `(seq_id, cmd)` via `ep_recv_seq_and_cmd` and handles the two
    slot-independent codes inline: shutdown (`0xFFFFFFFF`, applies to
    the whole worker, seq_id ignored) and alloc-slot (`0xFFFFFFF1`,
    frees any prior occupant of `slots[seq_id]` then allocates a fresh
    `SequenceState`).
  * Per-command dispatch (prefill/decode/verify K=2/3/4) moves into a
    new `ep_worker_dispatch_cmd(cmd, seq)` helper. The body is verbatim
    from before — only the prelude (slot lookup + alloc handling) moved.
  * Under v2 we defensively `bail!` if the worker's freshly-claimed
    SSM-pool slot doesn't equal the seq_id the head broadcast. Head and
    worker both pop from a `free_slots: Mutex<Vec<usize>>` in matched
    order so this should always hold — the check is here so we catch
    the invariant breaking loudly rather than corrupting KV.
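
A simplified sketch of the slot routing described above. The command codes are the ones named in this commit; the `SequenceState` body, the `String` error type, the `Ok(false)` return on shutdown, and the counter that stands in for real dispatch are assumptions for illustration:

```rust
pub const CMD_SHUTDOWN: u32 = 0xFFFF_FFFF;
pub const CMD_ALLOC_SLOT: u32 = 0xFFFF_FFF1;

/// Placeholder for the real per-sequence state.
#[derive(Debug, Default, PartialEq)]
pub struct SequenceState {
    pub commands_seen: usize,
}

/// One worker step. Returns Ok(false) on shutdown, Ok(true) to keep looping.
pub fn worker_step(
    slots: &mut [Option<SequenceState>],
    seq_id: u32,
    cmd: u32,
) -> Result<bool, String> {
    // Shutdown applies to the whole worker; seq_id is ignored.
    if cmd == CMD_SHUTDOWN {
        return Ok(false);
    }
    let idx = seq_id as usize;
    let capacity = slots.len();
    let slot = slots
        .get_mut(idx)
        .ok_or_else(|| format!("seq_id {idx} exceeds slot capacity {capacity}"))?;
    // Alloc drops any prior occupant, then installs a fresh state.
    if cmd == CMD_ALLOC_SLOT {
        *slot = Some(SequenceState::default());
        return Ok(true);
    }
    // Every other command must address an allocated slot.
    match slot.as_mut() {
        Some(seq) => {
            seq.commands_seen += 1; // stand-in for prefill/decode/verify dispatch
            Ok(true)
        }
        None => Err(format!("cmd {cmd:#x} for unallocated slot {idx}")),
    }
}
```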

Trait + dispatch chain:

  * `Model::ep_worker_step` signature is now `(&mut
    [Option<SequenceState>])`. Default impl returns `Ok(true)` as before.
  * `ep_worker_step_dispatch` (ep_misc.rs) and the TransformerModel
    trait impl (mod.rs) forward the slots slice through.

Worker entry in `spark-server/src/main_modules/serve_phases/build.rs`:

  * Allocates a `Vec<Option<SequenceState>>` of `args.max_batch_size`
    `None`s. Pre-allocates slot 0 with `alloc_sequence()` so v1 (which
    never issues an explicit alloc command before its first decode)
    keeps working without head-side changes.
  * On shutdown, walks every slot and frees occupants.

Backward compatibility:

  * With `ATLAS_EP_PROTOCOL` unset/!=`v2`: `ep_recv_seq_and_cmd` returns
    `seq_id=0` regardless of what the head sends (helper skips the
    preamble read entirely). `slots[0]` is pre-allocated. Every command
    routes to slot 0. Wire shape + worker semantics are byte-identical
    to the pre-PR singleton path.
  * With `ATLAS_EP_PROTOCOL=v2`: worker reads the preamble, routes
    correctly, alloc commands fill in additional slots as the head
    requests them.

The gate clamp in `serve.rs:307-314` still forces `max_batch_size=1`
under EP. Lifting that is commit 5 — at which point the scheduler can
actually drive multiple active sequences. Without commit 5, v2 is
inert because the head never has more than one active sequence to
preface with a nonzero seq_id.

File size: `impl_a2.rs` grew from 359 to 417 LoC. Under the 500-line CI
guideline; the existing match arms account for most of the bulk and
weren't worth splitting into a separate module for this PR.
Two changes that pair conceptually: the worker pre-allocates every slot
in its slots Vec at startup (not just slot 0), and head-side compaction
in retire_finished_sequences is skipped when the model reports
ep_protocol_v2().

Pre-allocation rationale. Under v1 the worker only ever saw slot 0, so
init-time pre-alloc of slot 0 sufficed. With the v2 protocol layer in
place the head broadcasts a per-cmd seq_id preamble — and a future
`max_batch_size > 1` lift will route prefill commands to slots > 0
without a preceding alloc broadcast (head's prefill_step claims a fresh
slot via alloc_sequence + broadcasts 0xFFFFFFF0; the 0xFFFFFFF1 alloc
command only fires on lifecycle events). Without all slots pre-claimed
on the worker, that future routing would bail with "cmd 0xfffffff0 for
unallocated slot N". Pre-allocating every slot up-front matches what
the head's own SSM pool does (`(0..max_slots).rev().collect()` + `pop`)
so sequential alloc_sequence() calls on both ranks return the same
slot_idx order, keeping the new_seq.slot_idx == slot_idx defensive
check in ep_worker_step satisfied.
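
Why matched allocation order holds can be seen in a small sketch of the free list. `SlotPool` is a hypothetical simplification (no Mutex, no SSM state); it keeps only the `(0..max_slots).rev().collect()` + `pop` behaviour quoted above:

```rust
/// Simplified free-slot pool: a reversed range, so `pop` yields 0, 1, 2, ...
pub struct SlotPool {
    free_slots: Vec<usize>,
}

impl SlotPool {
    pub fn new(max_slots: usize) -> Self {
        Self {
            free_slots: (0..max_slots).rev().collect(),
        }
    }

    /// Claims the next free slot, or None when the pool is exhausted.
    pub fn alloc(&mut self) -> Option<usize> {
        self.free_slots.pop()
    }

    /// Returns a slot to the pool; it becomes the next one handed out.
    pub fn free(&mut self, slot: usize) {
        self.free_slots.push(slot);
    }
}
```

Two pools built with the same `max_slots` and driven by the same sequence of alloc/free calls hand out identical indices, which is the property the defensive check relies on.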

Skip-compaction rationale. retire_finished_sequences compacts the
active vec so position == slot_idx, then tags the retired entry with
usize::MAX as a do-not-double-free sentinel. Under v2, two things
break: (1) moving SSM states on the head only would leave the worker's
mirror at the original slot since the worker is keyed on slot_idx not
active-set position, and (2) usize::MAX cast to u32 is 0xFFFFFFFF —
the v1 shutdown command — which the worker would read as a real seq_id
and trip its bounds check on the next preamble broadcast. Pre-allocated
slots stay valid in place across the swap_remove, and the per-slot
CUDA graph cache stays warm because the seq never moved.
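
The sentinel collision in point (2) is a plain truncating cast. A small sketch, with hypothetical helper names, of the collision and of a bounds-checked conversion that would refuse the sentinel:

```rust
pub const CMD_SHUTDOWN: u32 = 0xFFFF_FFFF;
pub const RETIRED_SENTINEL: usize = usize::MAX;

/// What an unchecked `slot_idx as u32` puts on the wire.
pub fn unchecked_wire_word(slot_idx: usize) -> u32 {
    slot_idx as u32
}

/// Converts only slot indices that are in range for the pool, so the
/// retired sentinel can never be broadcast as a seq_id.
pub fn checked_seq_id(slot_idx: usize, max_slots: usize) -> Option<u32> {
    if slot_idx < max_slots && slot_idx <= u32::MAX as usize {
        Some(slot_idx as u32)
    } else {
        None
    }
}
```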

No behavior change under v1. With max_batch_size=1 (the standing EP
clamp at serve.rs:307-314), the slots Vec has length 1, pre-allocation
claims exactly slot 0, and active.len() never exceeds 1 so the
compaction branch is unreachable. ep_protocol_v2() defaults false, so
skip_compaction is false on v1, identical legacy behavior.
Bench-validated end-to-end on 2× GB10 (EP=2,
qwen3.5-122b-a10b NVFP4, MTP nvfp4 speculative): N=4 and N=8 concurrent
decode now correct and coherent. Single-seq baseline unchanged.

Three coordinated changes:

1. decode_a2.rs — decode_batch_dispatch's EP branch was a known dead
   path under v1 (max_batch_size=1 clamp). Three things needed to be
   true at once to make it correct for N>1 EP:

   a) Per-layer NCCL allreduces must align in size and order with the
      worker's matching allreduces. The worker runs decode() per slot
      in ep_worker_step, so the head must also run decode() per seq —
      no batched decode_multi_seq under EP.

   b) The order of ops submitted to the comm matters. Worker submits
      per seq: broadcast(preamble), then N_layer all_reduces. Head
      previously batched all N preambles up-front from the scheduler,
      which made head submit [B,B,B,B,AR,AR,...] while worker
      submitted [B,AR,...,AR,B,AR,...,AR,...]. NCCL collectives match
      by submission position; mismatched positions deadlocked the
      comm. Observed empirically as "NCCL broadcast took 51.1s" on the
      worker followed by stale comm reads. Now the broadcasts live
      inside decode_batch_dispatch's EP branch, interleaved with each
      self.decode() — both ranks submit [B,AR,...,AR,B,AR,...] in
      matching order.

   c) self.decode() writes single-row logits to row 0 of the logits
      buffer on every call. Looping N decodes overwrites the buffer
      so process_decode_logits ends up sampling N rows of the last
      seq's logits. Stage each seq's row to host immediately after
      its decode() (the buffer is fresh within the scope of one
      decode()) then upload the assembled [n, vocab] back to the
      logits buffer before returning. Same pattern as the existing
      MLA per-seq fallback below.

   And one stream subtlety: decode_dispatch overrides the caller's
   `stream` parameter and uses `self.gpu.default_stream()` internally
   for its forward-pass kernels. Issuing the per-seq D2H copy on the
   scheduler's stream=0 (legacy NULL stream) landed it on a different
   CUDA stream than the GEMV that wrote the logits. The copy could
   read stale logits even though both streams "should" have synced.
   Use self.gpu.default_stream() throughout the EP path so the copy
   queues onto the same stream as the GEMV writes.

   Also broadcast in the n=1 short-circuit so the scheduler is fully
   relieved of decode-broadcast responsibility (the EP n>1 branch
   couldn't move broadcasts inline without the n=1 path also handling
   its own).

2. decode_step.rs — remove the per-token broadcast loop. The
   responsibility moved into decode_batch_dispatch, so step_decode_only
   no longer needs to know about EP at the cmd-broadcast layer.

3. serve.rs — honor --max-batch-size under EP when ep_protocol_v2()
   returns true. v1 still clamps to 1 so existing deployments are
   byte-identical without ATLAS_EP_PROTOCOL=v2.
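
The ordering problem in 1(b) can be modelled without NCCL by listing the collectives each rank submits and finding the first position where they differ. This is an illustrative model only; `CommOp` and the function names are hypothetical:

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommOp {
    Broadcast,
    AllReduce,
}

/// Worker order (and the fixed head order): per sequence, one preamble
/// broadcast followed by `layers` all-reduces.
pub fn interleaved_order(n_seqs: usize, layers: usize) -> Vec<CommOp> {
    let mut ops = Vec::with_capacity(n_seqs * (layers + 1));
    for _ in 0..n_seqs {
        ops.push(CommOp::Broadcast);
        ops.extend(std::iter::repeat(CommOp::AllReduce).take(layers));
    }
    ops
}

/// Old head order: all preambles up front, then every all-reduce.
pub fn upfront_order(n_seqs: usize, layers: usize) -> Vec<CommOp> {
    let mut ops = vec![CommOp::Broadcast; n_seqs];
    ops.extend(std::iter::repeat(CommOp::AllReduce).take(n_seqs * layers));
    ops
}

/// First submission position at which the two ranks disagree, if any.
pub fn first_mismatch(head: &[CommOp], worker: &[CommOp]) -> Option<usize> {
    head.iter().zip(worker).position(|(h, w)| h != w)
}
```

With more than one sequence the up-front order diverges at position 1 (the head submits a second broadcast while the worker submits its first all-reduce); with a single sequence the two orders coincide, which is why v1 never hit this.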

Operational result on the motivating workload (nemoclaw 122B EP=2,
qwen3.5-122b-a10b NVFP4 + MTP nvfp4 speculative, --max-batch-size 4):

  Wall-clock (concurrent users, max_tokens=80 each):
    v1 max_batch=1 k=4: 2 of 4 tail-spiked to 605s (head-of-line)
    v2 max_batch=4 N=4: 7.01s  (all 4 coherent)
    v2 max_batch=4 N=8: 13.12s (all 8 coherent, 4 in-flight + 4 queued)

  Per-seq throughput is lower than the single-seq baseline (5-10 tok/s
  vs ~38 tok/s) because the head still runs N sequential forward passes
  and the per-step host-staging adds overhead. The user-visible win is
  tail-latency elimination, not aggregate throughput. A true batched-EP
  forward pass — Option A in PR_NOTES — remains the follow-up for the
  throughput multiplier; this PR is the structural prerequisite that
  also delivers the head-of-line fix today.
Eliminates the host round-trip in decode_batch_dispatch's EP branch.
Previously: each per-seq decode() wrote row 0 of the logits buffer, the
host code copied row 0 to a staging Vec, then uploaded the assembled
[n, vocab] back to logits. Two PCIe transfers per seq plus one final
upload.

Now: iterate in reverse, decode each seq (still writing to row 0), then
issue a device-to-device copy from row 0 to the target row i. For i=0
(processed last in the reverse iteration), no copy needed — row 0
already holds seq 0's logits. Stays on GPU memory throughout.
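
A host-side sketch of the reverse-iteration trick, with a `Vec<f32>` standing in for the device logits buffer and `copy_within` standing in for the device-to-device copy. The function name and the `per_seq_logits` input (what each decode() would write to row 0) are illustrative assumptions:

```rust
/// Builds a row-major [n, vocab] logits buffer when each "decode" can only
/// write row 0. Every entry of `per_seq_logits` must have length `vocab`.
pub fn decode_batch_reverse(per_seq_logits: &[Vec<f32>], vocab: usize) -> Vec<f32> {
    let n = per_seq_logits.len();
    let mut logits = vec![0.0f32; n * vocab];
    for i in (0..n).rev() {
        assert_eq!(per_seq_logits[i].len(), vocab, "row {i} has wrong width");
        // Stand-in for decode(): always overwrites row 0.
        logits[..vocab].copy_from_slice(&per_seq_logits[i]);
        if i > 0 {
            // Stand-in for the d2d copy from row 0 to row i.
            logits.copy_within(0..vocab, i * vocab);
        }
        // For i == 0 nothing is copied: row 0 already holds seq 0's logits.
    }
    logits
}
```

Iterating forward instead would need a copy for every sequence, since row 0 would end up holding the last sequence's logits rather than the first's.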

Bench on 2× GB10 (qwen3.5-122b-a10b NVFP4, EP=2,
--max-batch-size 4, MTP nvfp4 speculative):

  N=4:  7.01s -> 5.89s (-16%)
  N=8: 13.12s -> 11.43s (-13%)

The eventual true batched-EP forward pass (one decode_multi_seq call
per step instead of N sequential decode()s) subsumes this — N rows
get written directly by the lm_head GEMV loop and no staging is
needed at all. Until that lands, the d2d cuts the visible bench wall
time without touching kernel correctness.
…i-seq decode

Two coordinated changes that together make the multi-seq decode path a
first-class entry point under EP, using the same kernel set that prefill
already uses.

1. Batched-EP protocol (decode_a2.rs + impl_a2.rs)

   Reintroduces `0xFFFFFFE0` — the batched-decode command code originally
   tried (and reverted) in the foundation cycle. Head broadcasts
   `(seq_id=0, 0xFFFFFFE0)` preamble + N + seq_ids[N] + tokens[N] via the
   new `ep_broadcast_decode_batch_dispatch` helper. Worker matches with
   `ep_worker_decode_batch` (handled BEFORE the slot_idx lookup in
   `ep_worker_step_impl`, like shutdown), reads the payload, builds an
   in-order `Vec<&mut SequenceState>` from the addressed slots, and
   dispatches into the shared compute path.

   The shared compute path is `decode_batch_compute_main` — extracted
   from `decode_batch_dispatch`'s former non-EP main branch so both
   ranks now reach it. The head's EP branch broadcasts the protocol
   primitive, then calls it. The worker's batched handler also calls
   it. No host-staging, no per-seq broadcast loop — one batched forward
   pass per step per rank.

   Comm-stream submission order per decode step is identical on both
   ranks: `B(0) B(0xFFFFFFE0) B(N) B*N(seq_ids) B*N(tokens)` then per
   MoE layer N × `comm.all_reduce(h*elem)` from the per-token loop
   inside the batched MoE path (#2 below).
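
A sketch of the batched-decode payload as a flat sequence of u32 words in the order given above: preamble seq_id 0, the command, N, then N seq_ids and N tokens. The real helpers issue separate broadcasts; flattening them into one `Vec<u32>` and the function names used here are assumptions for illustration:

```rust
pub const CMD_DECODE_BATCH: u32 = 0xFFFF_FFE0;

/// Head side: lay out (0, cmd, N, seq_ids[N], tokens[N]).
pub fn encode_decode_batch(seq_ids: &[u32], tokens: &[u32]) -> Option<Vec<u32>> {
    let n = seq_ids.len();
    if n == 0 || n != tokens.len() || n > u32::MAX as usize {
        return None;
    }
    let mut words = vec![0, CMD_DECODE_BATCH, n as u32];
    words.extend_from_slice(seq_ids);
    words.extend_from_slice(tokens);
    Some(words)
}

/// Worker side: validate the header, then split the payload in half.
pub fn parse_decode_batch(words: &[u32]) -> Option<(Vec<u32>, Vec<u32>)> {
    let header = words.get(..3)?;
    let payload = words.get(3..)?;
    if header[1] != CMD_DECODE_BATCH {
        return None;
    }
    let n = header[2] as usize;
    if payload.len() != n.checked_mul(2)? {
        return None;
    }
    let (seq_ids, tokens) = payload.split_at(n);
    Some((seq_ids.to_vec(), tokens.to_vec()))
}
```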

2. Grouped-MoE in multi-seq decode (trait_decode_multi_seq.rs +
   qwen3_attention/trait_impl/multi_seq/ffn.rs)

   The multi-seq decode path on both qwen3_ssm and qwen3_attention
   layers previously called `self.ffn.forward(normed2_i, ctx, stream)`
   inside a per-token loop — N × (gate GEMV + top_k expert GEMVs +
   weighted sum) per MoE sublayer.

   Refactor the loop into three phases:
     A: per-token residual_add_rms_norm, laying out `norm_output[0..n]`
        as a contiguous [N, h] MoE input
     B: ONE call to `self.ffn.forward_prefill(norm_base, n, ctx, stream)`
        — the grouped-GEMM path that the prefill scheduler already uses.
        Sort tokens by expert, one grouped gate+up GEMM, SiLU, one
        grouped down GEMM, unpermute.
     C: per-token residual_add reading `moe_output[i]`.

   Bug-#6 invariant preserved: SSM outputs are still copied to
   `ssm_out_safe` before Phase A so the batched MoE's writes to
   `moe_output[0..n]` don't clobber the SSM outputs the rms_norm reads.

Perf on 2× GB10 (qwen3.5-122b-a10b NVFP4, EP=2,
--max-batch-size 4, MTP nvfp4 speculative, greedy bench warm):

  N=1: ~38 tok/s aggregate
  N=2: ~29 tok/s aggregate
  N=4: ~38 tok/s aggregate (per-seq ~10 tok/s)
  N=8: ~26 tok/s aggregate (4 in-flight + 4 queued)

Same throughput as the d2d-only path the previous commit shipped —
the architectural ceiling on this hardware at low N is roughly the
single-seq decode rate. The per-seq drop is the expected cost of
sharing compute across N tokens; aggregate stays at the GPU's
weight-load-bandwidth ceiling.

What this PR's structural changes enable that d2d-only couldn't:
- A single-call multi-seq decode entry point that uses the same MoE
  kernels (`moe_w4a16_grouped_gemm_ptrtable`, `moe_topk_softmax_batched`,
  `moe_sort_by_expert`, `moe_unpermute_reduce_indexed`) as prefill —
  one less code path to keep in sync, one less code path to optimize.
- A clear hook for future kernel-level wins: a true batched
  `comm.all_reduce(N*h*elem)` instead of N per-token `comm.all_reduce(h)`
  would land cleanly here without changing the dispatch.
- EP=4 / higher-N regimes where expert reuse becomes meaningful
  (N >> 256/top_k) will exercise the same path; today's grouped-GEMM
  kernel is already in production for prefill and known-correct.

No behavior change under v1 (ATLAS_EP_PROTOCOL unset or != "v2"):
the EP n>1 path is still unreachable because `serve.rs` clamps
max_batch_size=1. Under v2 with N=1, the n=1 short-circuit fires;
the batched code path only sees N>1 when the gate is flipped.
The SSM multi-seq decode path delegated every sequence to the full
single-token decode(), running N independent single-token MoE forwards
(N x top_k expert GEMVs + N per-token all_reduces under EP). The batched
grouped-MoE code meant to replace it sat behind an early return,
unreachable since the bug-#6 buffer-aliasing debugging.

Replace the delegation with a per-seq SSM mixer loop (conv1d/GDN
recurrent state is inherently per-sequence, so the proven single-token
kernels stay) feeding a batch-dispatched MoE:
  - N=2/3: fused forward_k2/k3 -- one batched all_reduce, no per-token
    launch overhead. SSM decode-step 44->35ms at N=2 on GB10
    (qwen3.5-122b-a10b NVFP4, EP=2).
  - N>=4: per-token MoE loop. The generic grouped-GEMM (forward_prefill)
    is a net loss for this 256-expert MoE at small batch -- per-expert M
    is ~1 and the sort/permute/ptr-table overhead, paid once per layer
    across 36 SSM layers, pushed the SSM step to ~140ms vs ~88ms
    per-token. forward_prefill is declined here until a true batched-EP
    MoE kernel exists.
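
The "per-expert M is ~1" claim follows from simple counting. A sketch, assuming routing spreads across as many distinct experts as possible; the top_k value of 8 used in the usage note is a hypothetical placeholder, not a figure from this PR:

```rust
/// Rows each active expert's GEMM sees when `n_tokens * top_k` routed rows
/// are spread over as many distinct experts as possible.
pub fn rows_per_active_expert(n_tokens: usize, top_k: usize, num_experts: usize) -> f64 {
    let rows = n_tokens * top_k;
    // No more experts can be active than there are routed rows.
    let active = rows.min(num_experts);
    if active == 0 {
        return 0.0;
    }
    rows as f64 / active as f64
}
```

With 256 experts and an assumed top_k of 8, N=4 gives 32 routed rows over up to 32 experts, so M is 1 and the grouped GEMM degenerates into many one-row problems plus sort and permute overhead. M only grows once N * top_k exceeds the expert count, for example 16 rows per expert at N=512.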

Buffer safety: each per-seq mixer writes its MoE input to norm_output[i]
(distinct per-seq offset); ssm_forward never touches norm_output and its
ssm_out is consumed within the same iteration, so nothing survives across
sequences and the old aliasing cannot recur.

Validated coherent + cross-seq isolated at N=2 and N=4. Removes the
unreachable batched-decode block.
Mirror of the SSM-layer fix (parent commit): the attention layers'
multi-seq FFN used the generic grouped-GEMM (forward_prefill) for the
N>=4 branch, which is a net loss for this 256-expert MoE at small batch
-- per-expert M ~1 and the sort/permute/ptr-table overhead dominates.

Replace it with the per-token MoE loop (identical to decode()'s MoE).
N=2/3 keep the fused forward_k2/k3 branches. Measured on GB10
(qwen3.5-122b-a10b NVFP4, EP=2, N=4): attention decode-block ~40->~24ms,
full step ~132->~122ms, no regression. This also makes the N=2/3 MLA
fallback path (force_seq_ffn) avoid the batched-MoE kernels it was never
safe with.

Validated coherent + cross-seq isolated at N=4.
@github-actions

github-actions Bot commented May 29, 2026


All contributors have signed the CLA. Thank you!
Posted by the CLA Assistant Lite bot.

@camerono

Contributor Author

I have read the CLA Document and I hereby sign the CLA

@tbraun96

tbraun96 commented Jun 3, 2026

Contributor

@camerono this is a genuinely strong PR, thank you. I went through it end to end and it is static-verified on my side: the workspace compiles, spark-model (65) and spark-server (489) test suites are green, and the v2 gating is correctly additive. ep_protocol_v2 is read once from ATLAS_EP_PROTOCOL at construction, defaults false, and every wire-format change (the seq_id preamble, slot pre-alloc, skip-compaction, honoring max_batch_size) is behind it, so the v1 default path is byte-for-byte unchanged. That is exactly the backward-compat property this needed.

What I liked:

  • Reusing slot_idx as the seq_id instead of inventing a parallel identifier. The whole forward pass already indexes by it, so the protocol just needed to express it. The defensive fail-fast checks (slot capacity bail, claim_slot ordering invariant, unallocated-slot bail) are the right call for a wire protocol.
  • The two head-side fixes the protocol exposed are the subtle ones and you caught both: staging each sequence's logits row so decode_batch_dispatch does not overwrite row 0 per seq, and interleaving the seq_id broadcasts inside the dispatch so the head's comm-stream op order stays in step with the worker's per-layer all-reduces (the NCCL deadlock).
  • The MoE dispatch decision is the real engineering. That reviving the dead batched path with the prefill grouped-GEMM is a ~60% regression at decode batch sizes, because per-expert M is about 1 and the sort/permute/ptr-table overhead dominates across 36 SSM layers, is a non-obvious trap, and fused forward_k2/k3 for N=2/3 plus per-token for N>=4 is the right shape.
  • ADR-0011 is excellent and unusually honest: documenting that batching gives no aggregate tok/s win at low N (the step is bandwidth and inter-node-NCCL bound), that the real wins are tail latency and admission, and that CUDA graphs do nothing here because launch is not the bottleneck. That framing will save the next person a lot of time.

Before I merge I want to close the one thing static analysis cannot prove: live multi-sequence correctness on 2x GB10 EP=2 with Qwen3.5-122B-A10B-NVFP4 and ATLAS_EP_PROTOCOL=v2. Specifically cross-sequence isolation at N=2 and N=4 (no contamination between concurrent requests), the MTP verify-broadcast path under multiple slots, and no NCCL stall under the interleaved broadcasts. I am queuing that run. Note there is a known pre-existing crash at concurrency >= 4 with deep (~8k) context on the SSM hybrid path that is unrelated to this PR, so I will treat an N=4 deep-context failure as a separate issue rather than a regression here. Will report the numbers back on this thread.

@tbraun96

tbraun96 commented Jun 3, 2026

Contributor

Validated live on 2x GB10 (EP=2) with Sehyo/Qwen3.5-122B-A10B-NVFP4, on an image built from current main + this PR. @camerono this holds up.

What I ran and saw:

  • Gate lifts correctly: with ATLAS_EP_PROTOCOL=v2 and --max-batch-size 4 the head logs 'EP v2 active: honoring max_batch_size=4' (v1 still logs 'forcing max_batch_size=1'), the SSM state pool comes up with 4 slots, and the scheduler starts in batched mode max_batch=4.
  • NCCL inter-node bring-up is clean (rank 0/2 over enp1s0f0np0, 2-rank send/recv), no deadlock from the interleaved seq-id broadcasts.
  • Cross-sequence isolation: fired distinct concurrent prompts at N=2, N=4, N=8 (capital-of-France, 17+25, opposite-of-hot, sky color, capital-of-Japan, 6x7, symbol-for-water, our-planet). Every response matched its OWN prompt with no contamination from a sibling in the batch. N=2 and N=4 were 8/8 content-correct; at N=8 all eight were also content-correct (the one that looked off was H2O vs my literal 'h2o' check plus thinking-mode, not a batch issue).
  • Both head and worker stayed up across all 14 concurrent requests, throughput ~24 tok/s, TTFT ~370 ms.

Two notes for the record, neither blocking this PR: the trailing role-marker leak (newline-user / newline-assistant) you may notice in the raw text is the separate #40 / #100 turn-termination issue on this template, not an EP=2 batch artifact. And I exercised the non-MTP batched path here; MTP+EP multi-slot is the remaining thing to stress, which ties into #94. The batch>1 protocol itself is correct and isolated, which is the core of this PR, so merging. Excellent work, and the ADR is a genuinely useful writeup.

@tbraun96 tbraun96 merged commit 015605f into Avarok-Cybersecurity:main Jun 3, 2026
10 of 12 checks passed
tbraun96 added a commit that referenced this pull request Jun 3, 2026
…panic (#118)

#101's ep_broadcast_seq_and_cmd (reached via ep_broadcast_cmd_for_seq from
the head's prefill / decode / mtp / lifecycle paths) called ep_broadcast_u32
unconditionally. On single-GPU self.comm is None, so the first generation
panicked 'ep_broadcast_u32 without comm' at impl_a2.rs and killed the
scheduler worker thread, bricking the server.

The original ep_broadcast_cmd no-ops on single-GPU via
ep_broadcast_cmd_dispatch's 'comm.is_some() && ep_world_size > 1' guard;
the per-seq variant added for the slot-mux protocol lost that guard. Add
the identical guard so all ep_broadcast_cmd_for_seq sites no-op when EP is
not active. EP=2 is unaffected (guard passes when comm is present and
ep_world_size > 1). Regression from #101, which only exercised the 2-rank
path; caught by live single-GPU testing.
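
A sketch of the guard this fix restores, under stated assumptions: `Comm` is an empty stand-in for the communicator, `Model` is a hypothetical struct rather than TransformerModel, and `sent` records what would have been broadcast:

```rust
/// Empty stand-in for the real communicator handle.
pub struct Comm;

pub struct Model {
    pub comm: Option<Comm>,
    pub ep_world_size: usize,
    pub ep_protocol_v2: bool,
    /// Words that would have gone on the wire, recorded for inspection.
    pub sent: Vec<u32>,
}

impl Model {
    /// Same condition as the guard quoted above.
    pub fn ep_active(&self) -> bool {
        self.comm.is_some() && self.ep_world_size > 1
    }

    /// Per-sequence broadcast: a no-op unless EP is active, so single-GPU
    /// runs never reach the broadcast that requires a communicator.
    pub fn ep_broadcast_cmd_for_seq(&mut self, seq_id: u32, cmd: u32) {
        if !self.ep_active() {
            return;
        }
        if self.ep_protocol_v2 {
            self.sent.push(seq_id);
        }
        self.sent.push(cmd);
    }
}
```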

Co-authored-by: Azeez Ishaqui <debaterishaqui@gmail.com>
@tbraun96 tbraun96 mentioned this pull request Jun 3, 2026
tscholak added a commit to tscholak/atlas that referenced this pull request Jun 5, 2026


heim/main is a principled fork of Avarok-Cybersecurity/atlas: posthoc
model-output sanitization has been eliminated end-to-end and xgrammar 2
is the sole intervention surface (see docs/atlas-chat-pipeline.md).
Upstream commits that would re-introduce posthoc cleanup are rejected
by design.

This `-s ours` merge records that every commit in upstream/main up to
and including 7a903aa has been triaged. The working tree is NOT
modified — content was selectively imported by cherry-pick or rewritten
to fit the chat_fsm architecture; this merge is a watermark so the next
sync measures only NEW upstream activity.

────────────────────────────────────────────────────────────────────
Audit of 28 upstream-only commits (heim/main..upstream/main @ 7a903aa):

PORTED (11) — cherry-picked clean or reconciled
  78659fcadee702  chore(deps): criterion 0.5 → 0.8
  abfa93f8ed859b  spark-server: respect disabled MiniMax thinking
  359a2ff753f66b  fix(decode): never synchronize during CUDA graph capture in debug dumps
  d0794b025eeeee  fix(prefix-cache): gate two-phase SSM prefill snapshot insert on vision pad
  8d11605ea0127c  fix(grammar): make qwen3_coder auto-mode tool triggers prefix-free
  c4ccf506f3814f  chore(deps): tokio 1.50 → 1.52, minijinja 2.19 → 2.20
  a248fcff6467a3  spark-server: respect bare-json auto tool choice
  8e933301a0734c  fix(api): max_thinking_tokens alias for thinking_token_budget
  1773b6c9174ef2  chore(deps): prometheus 0.13.4 → 0.14.0
  ddc7080dd0d740  feat(stream): vLLM-compatible return_token_ids (chat_fsm reconciliation)
  87b7bb31c43918  spark-server: --default-chat-template-kwargs CLI flag

REJECTED (3) — violate the posthoc-elimination principle
  40bcf92  cancel scheduler when a loop guard fires
             depends on SimHashLoopGuard / check_loop_watchdog (deleted Stage 2 B6)
  9ada9be  avoid orphan tool messages in F32
             F-coded failure system deleted in Stage 2 B3; injects <system-reminder>
  9d19c32  finish_reason=length when tool-loop guard caps
             same loop-guard dependency as 40bcf92

SKIPPED (12) — not applicable to single-GPU GB10 Qwen3.6-FP8 deployment
  c85949f  per-tensor FP8 dispatch for MIXED_PRECISION (Nvidia variant; we use Qwen-official block-scaled)
  51490f9  route MTP verify through HSS orchestrator (we use neither MTP nor HSS)
  700aea5  guard per-seq broadcast helper (function not in our tree — PR Avarok-Cybersecurity#101 EP v2 never merged)
  3dfa8c9  FP32 MiniMax ModelOpt dense weights (not our model)
  015605f  max_batch_size=1 under EP=2 multi-GPU (single-GPU)
  6708d4e  Qwen3.5-397B NVFP4 loader (not our model)
  9a5edb9  --hosts flag on sparkrun docs (different docs site)
  fd2ad74  Discord invite refresh (different community)
  c741e3c  per-model Dockerfile repairs (no Docker per project policy)
  be740f7  Qwen3.5-397B NVFP4 EP=4 kernel target (not our model)
  92f0da9  CLA workflow signature handling (different repo / signers)
  1b6feff  CLA Assistant workflow (different repo / signers)

DEFERRED (2) — Stage 7 review required
  7a903aa  free dead BF16 intermediates in dense loader
             second hunk depends on e146bef's SSM FP8 prefill block
  e146bef  qwen3.6-mtp dense + reasoning-parser rewrite + numerous fixes (Avarok-Cybersecurity#63)
             375-file 42K-insertion mega-commit; mixes principled work with
             suspect material (digit-normalised content-loop watchdog,
             post-think content cap citing F1-F5). Needs line-by-line review.

Persistent ledger for future syncs lives in
~/.claude/projects/-Users-tscholak-Projects-heim/memory/
reference_atlas_upstream_sync.md

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
tbraun96 added a commit that referenced this pull request Jun 8, 2026
Brings the 26 mainline commits (PRs #88/#89/#99/#101/#102/#103/#105/#106/
#107/#113/#114/#115/#116/#117/#118/#120/#121 and the return-token-ids
streaming feature) onto the campaign branch. 15 conflicts resolved,
preserving the branch's verified work while integrating main's:

Kept (campaign / branch intent):
- EOS-escape default-ON bake + the byte-boundary !cleaned.is_empty() fix
  (handle_token.rs), the in-think tool-call leak scanner state
  (reasoning_xml_*), and the soft/hard tool-validation passthrough
  (tool_handlers.rs).
- #6097910 removal of the always-on F-guard / failures machinery:
  apply_failure_guards (chat_phases) and the F44 streaming circuit-breaker
  stay removed; api/failures/circuit{,_tests}.rs stay deleted (no callers).
- #227 FP32-residual removal: residual_elem hardcoded to BF16 (2 bytes) in
  ffn.rs and the SSM multi-seq decode (use_fp32_residual is gone).
- ATLAS_TOOL_SHORT_TRIGGER kill-switch in the qwen3_coder trigger builder,
  and the 3-arg compile_qwen3_coder_tool_grammar (value_close) signature.

Integrated (from main):
- return-token-ids streaming: pending_token_ids state field +
  .with_token_ids(state.take_ids_if(..)) on emitted chunks + token_ids module.
- #88 prefix-free shared-name tool triggers (the seen-dedup) folded into the
  short-trigger branch, plus the #88 regression test (adapted to 3-arg).
- main's refactored SSM multi-seq decode (decode_multi_seq_inner) carrying the
  concurrent-decode stride fix, on the new norm-op signatures.

Tool-call salvage block from main was NOT taken: the crate::tool_salvage
module does not exist on this branch, so the block would not compile.

Verified: cargo check + cargo check --tests (spark-server, spark-model) green.

Co-Authored-By: Azeez Ishaqui <debaterishaqui@gmail.com>

2 participants